{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from torch import nn\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "import re\n",
    "import json\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from collections import Counter\n",
    "from pensmodule.Generator.train import *\n",
    "import json\n",
    "import matplotlib.pyplot as plt"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"\n",
    "device = torch.device('cuda:0')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Prepare"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "- **Config & Data**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from data import *\n",
    "with open('config.json') as f:\n",
    "    config = json.load(f)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "sources = np.load('../../data2/sources.npy')\n",
    "target_inputs = np.load('../../data2/target_inputs.npy')\n",
    "target_outputs = np.load('../../data2/target_outputs.npy')\n",
    "embedding_matrix = np.load('../../data2/embedding_matrix2.npy')\n",
    "with open('../../data2/dict.pkl', 'rb') as f:\n",
    "    news_index,category_dict,word_dict = pickle.load(f)\n",
    "index2word = {}\n",
    "for k,v in word_dict.items():\n",
    "    index2word[v] = k\n",
    "print(len(word_dict),embedding_matrix.shape)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "- **Model**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from pensmodule.Generator import HeadlineGen\n",
    "model = HeadlineGen(config['model'], embedding_matrix, index2word, device, pointer_gen=True).to(device)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from pensmodule.UserEncoder import NRMS\n",
    "\n",
    "usermodel = NRMS(embedding_matrix)\n",
    "usermodel.load_state_dict(torch.load('../../runs/userencoder/NAML-2.pkl'))\n",
    "usermodel = usermodel.to(device)\n",
    "usermodel.eval()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "- **Load Trainer**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "# here you can set different modes for computing coverage scores\n",
    "trainer = Trainer(config, model, usermodel, device, mode=4, experiment_name='exp')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Pretrain Seq2seq model"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "**Noted:**\\\n",
    "**For a fair comparison, here we advise using other source and targets (from your own collected news datasets) for pretraining.**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "global_user_embed = np.load('../../data2/global_user_embed2.npy')\n",
    "global_user_embed = torch.as_tensor(global_user_embed, device=device).float()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "trainer._init_optimizer()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "for epoch in range(1,4):\n",
    "    print('epoch:', epoch)\n",
    "    s_dset = Seq2SeqDataset(sources, target_inputs, target_outputs)\n",
    "    train_iter = DataLoader(s_dset, batch_size=128, shuffle=True)\n",
    "    \n",
    "    trainer.pretrain(train_iter, global_user_embed)\n",
    "    trainer.save_checkpoint(tag='pretrain_epoch_'+str(epoch))"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Train Personalized Generator"
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "- **data loader**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "with open('../../data2/TrainUsers.pkl', 'rb') as f:\n",
    "    TrainUsers = pickle.load(f)\n",
    "with open('../../data2/TrainSamples.pkl', 'rb') as f:\n",
    "    TrainSamples = pickle.load(f)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "news_scoring = np.load('../../data2/news_scoring2.npy')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "i_dset = ImpressionDataset(news_scoring, sources, target_inputs, target_outputs, TrainUsers, TrainSamples)\n",
    "data_loader = DataLoader(i_dset, batch_size=128, shuffle=True)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "- **train**"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "trainer.load_checkpoint(tag='pretrain_epoch_3')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "trainer._init_evaluator_()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "trainer._init_context_()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "optimizer_params={'lr': 0.000001}\n",
    "scheduler_params={'step_size': 200, 'gamma': 0.98}\n",
    "trainer._init_optimizer(optimizer_params=optimizer_params,scheduler_params=scheduler_params)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "rewards = trainer.train(data_loader,train_option='a2c', tag='mod4')"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "x = np.arange(len(rewards))\n",
    "fig = plt.gcf()\n",
    "plt.plot(x[:3000],rewards[:3000])\n",
    "\n",
    "plt.xlabel(\"step\")\n",
    "plt.ylabel(\"reward\")\n",
    "plt.show()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Test"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "with open('../../data2/TestUsers.pkl', 'rb') as f:\n",
    "    TestUsers = pickle.load(f)\n",
    "with open('../../data2/TestSamples.pkl', 'rb') as f:\n",
    "    TestSamples = pickle.load(f)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from pensmodule.Generator import *\n",
    "model_path = '../runs/seq2seq/cov/checkpoint_train_mod4_step_2000.pth'\n",
    "model = load_model_from_ckpt(model_path).to(device)\n",
    "model.eval()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from pensmodule.UserEncoder import NRMS\n",
    "\n",
    "usermodel = NRMS(embedding_matrix)\n",
    "usermodel.load_state_dict(torch.load('../../runs/userencoder/NAML-2.pkl'))\n",
    "usermodel = usermodel.to(device)\n",
    "usermodel.eval()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "i_dset = TestImpressionDataset(news_scoring, sources, TestUsers, TestSamples)\n",
    "test_iter = DataLoader(i_dset, batch_size=16, shuffle=False)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "from pensmodule.Generator.eval import predict\n",
    "refs, hyps, scores1, scores2, scoresf = predict(usermodel, model, test_iter, device, index2word, beam=False, beam_size=3, eos_id=2)\n",
    "# refs, hyps, scores1, scores2, scoresf = predict(usermodel, model, test_iter, device, index2word, beam=True, beam_size=3, eos_id=2)"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [
    "scores1.mean(), scores2.mean(), scoresf.mean()"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "source": [],
   "outputs": [],
   "metadata": {}
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.3 64-bit ('py3': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "interpreter": {
   "hash": "f1ea14d0bd9c7a0f4f2d255c0662c7e1119328c505d1c17d9d8f159415dfcf69"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}